{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/Denys88/rl_games/blob/master/notebooks/brax_training.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "JbZsMYmyZiVr"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting git+https://github.com/Denys88/rl_games\n",
      "  Cloning https://github.com/Denys88/rl_games to /tmp/pip-req-build-mr7os74g\n",
      "  Running command git clone --filter=blob:none --quiet https://github.com/Denys88/rl_games /tmp/pip-req-build-mr7os74g\n",
      "  Resolved https://github.com/Denys88/rl_games to commit 42c076edaf071e7f5a5f9154e1f3c1302c038aba\n",
      "  Installing build dependencies ... \u001b[?25ldone\n",
      "\u001b[?25h  Getting requirements to build wheel ... \u001b[?25ldone\n",
      "\u001b[?25h  Preparing metadata (pyproject.toml) ... \u001b[?25ldone\n",
      "\u001b[?25hRequirement already satisfied: PyYAML<7.0,>=6.0 in /home/viktor4090/anaconda3/envs/jax312/lib/python3.12/site-packages (from rl_games==1.6.1) (6.0.2)\n",
      "Collecting gym<0.24.0,>=0.23.0 (from gym[classic-control]<0.24.0,>=0.23.0->rl_games==1.6.1)\n",
      "  Using cached gym-0.23.1.tar.gz (626 kB)\n",
      "  Installing build dependencies ... \u001b[?25ldone\n",
      "\u001b[?25h  Getting requirements to build wheel ... \u001b[?25ldone\n",
      "\u001b[?25h  Preparing metadata (pyproject.toml) ... \u001b[?25ldone\n",
      "\u001b[?25hCollecting opencv-python<5.0.0,>=4.5.5 (from rl_games==1.6.1)\n",
      "  Using cached opencv_python-4.10.0.84-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)\n",
      "Collecting psutil<6.0.0,>=5.9.0 (from rl_games==1.6.1)\n",
      "  Using cached psutil-5.9.8-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (21 kB)\n",
      "Requirement already satisfied: setproctitle<2.0.0,>=1.2.2 in /home/viktor4090/anaconda3/envs/jax312/lib/python3.12/site-packages (from rl_games==1.6.1) (1.3.4)\n",
      "Requirement already satisfied: tensorboard<3.0.0,>=2.8.0 in /home/viktor4090/anaconda3/envs/jax312/lib/python3.12/site-packages (from rl_games==1.6.1) (2.18.0)\n",
      "Requirement already satisfied: tensorboardX<3.0,>=2.5 in /home/viktor4090/anaconda3/envs/jax312/lib/python3.12/site-packages (from rl_games==1.6.1) (2.6.2.2)\n",
      "Collecting wandb<0.13.0,>=0.12.11 (from rl_games==1.6.1)\n",
      "  Using cached wandb-0.12.21-py2.py3-none-any.whl.metadata (7.2 kB)\n",
      "INFO: pip is looking at multiple versions of rl-games to determine which version is compatible with other requirements. This could take a while.\n",
      "\u001b[31mERROR: Package 'rl-games' requires a different Python: 3.12.5 not in '<3.11,>=3.7.1'\u001b[0m\u001b[31m\n",
      "\u001b[0m"
     ]
    }
   ],
   "source": [
    "!pip install git+https://github.com/Denys88/rl_games"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "Mq4PwGLn13Wm"
   },
   "outputs": [],
   "source": [
    "#@title Brax training example\n",
    "#@markdown ## ⚠️ PLEASE NOTE:\n",
    "#@markdown This colab runs using a GPU runtime. From the Colab menu, choose Runtime > Change Runtime Type, then select **'GPU'** in the dropdown.\n",
    "\n",
    "from datetime import datetime\n",
    "import functools\n",
    "import os\n",
    "\n",
    "from IPython.display import HTML, clear_output\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "try:\n",
    "  import brax\n",
    "except ImportError:\n",
    "  !pip install git+https://github.com/google/brax.git@main\n",
    "  clear_output()\n",
    "  import brax\n",
    "\n",
    "from brax import envs\n",
    "from brax.io import html\n",
    "from brax.io import model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "DnUgCbN3OwLB",
    "outputId": "e266fbc5-06ad-4679-e208-c5b485acaadd"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GPU 0: NVIDIA GeForce RTX 4090 Laptop GPU (UUID: GPU-53f9f624-49f2-daaf-3d97-93c5ff2684f6)\n"
     ]
    }
   ],
   "source": [
    "!nvidia-smi -L"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "6qvHCGgpxrvZ"
   },
   "outputs": [],
   "source": [
    "%load_ext tensorboard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "      <iframe id=\"tensorboard-frame-4fb9dd5b6858e6d7\" width=\"100%\" height=\"800\" frameborder=\"0\">\n",
       "      </iframe>\n",
       "      <script>\n",
       "        (function() {\n",
       "          const frame = document.getElementById(\"tensorboard-frame-4fb9dd5b6858e6d7\");\n",
       "          const url = new URL(\"/\", window.location);\n",
       "          const port = 6006;\n",
       "          if (port) {\n",
       "            url.port = port;\n",
       "          }\n",
       "          frame.src = url;\n",
       "        })();\n",
       "      </script>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%tensorboard --logdir 'runs/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "5Zym5GEjw_iE"
   },
   "outputs": [],
   "source": [
    "## ant brax config:\n",
    "ant_config = {'params': {'algo': {'name': 'a2c_continuous'},\n",
    "  'config': {'bound_loss_type': 'regularisation',\n",
    "   'bounds_loss_coef': 0.0,\n",
    "   'clip_value': True,\n",
    "   'critic_coef': 4,\n",
    "   'e_clip': 0.2,\n",
    "   'entropy_coef': 0.0,\n",
    "   'env_config': {'env_name': 'ant', 'seed': 5},\n",
    "   'env_name': 'brax',\n",
    "   'gamma': 0.99,\n",
    "   'grad_norm': 1.0,\n",
    "   'horizon_length': 8,\n",
    "   'kl_threshold': 0.008,\n",
    "   'learning_rate': '3e-4',\n",
    "   'lr_schedule': 'adaptive',\n",
    "   'max_epochs': 5000,\n",
    "   'mini_epochs': 4,\n",
    "   'minibatch_size': 32768,\n",
    "   'name': 'ant-brax',\n",
    "   'normalize_advantage': True,\n",
    "   'normalize_input': True,\n",
    "   'normalize_value': True,\n",
    "   'num_actors': 4096,\n",
    "   'player': {'render': True},\n",
    "   'ppo': True,\n",
    "   'reward_shaper': {'scale_value': 0.1},\n",
    "   'schedule_type': 'standard',\n",
    "   'score_to_win': 20000,\n",
    "   'tau': 0.95,\n",
    "   'truncate_grads': True,\n",
    "   'use_smooth_clamp': True,\n",
    "   'value_bootstrap': True},\n",
    "  'model': {'name': 'continuous_a2c_logstd'},\n",
    "  'network': {'mlp': {'activation': 'elu',\n",
    "    'initializer': {'name': 'default'},\n",
    "    'units': [256, 128, 64]},\n",
    "   'name': 'actor_critic',\n",
    "   'separate': False,\n",
    "   'space': {'continuous': {'fixed_sigma': True,\n",
    "     'mu_activation': 'None',\n",
    "     'mu_init': {'name': 'default'},\n",
    "     'sigma_activation': 'None',\n",
    "     'sigma_init': {'name': 'const_initializer', 'val': 0}}}},\n",
    "  'seed': 5}}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "id": "dt2q0HgmxDrb"
   },
   "outputs": [],
   "source": [
    "## config from the openai gym mujoco (should have the same network and normalization) to render result:\n",
    "humanoid_config = {'params': {'algo': {'name': 'a2c_continuous'},\n",
    "  'config': {'bound_loss_type': 'regularisation',\n",
    "   'bounds_loss_coef': 0.0,\n",
    "   'clip_value': True,\n",
    "   'critic_coef': 4,\n",
    "   'e_clip': 0.2,\n",
    "   'entropy_coef': 0.0,\n",
    "   'env_config': {'env_name': 'humanoid', 'seed': 5},\n",
    "   'env_name': 'brax',\n",
    "   'gamma': 0.99,\n",
    "   'grad_norm': 1.0,\n",
    "   'horizon_length': 16,\n",
    "   'kl_threshold': 0.008,\n",
    "   'learning_rate': '3e-4',\n",
    "   'lr_schedule': 'adaptive',\n",
    "   'max_epochs': 5000,\n",
    "   'mini_epochs': 5,\n",
    "   'minibatch_size': 32768,\n",
    "   'name': 'humanoid-brax',\n",
    "   'normalize_advantage': True,\n",
    "   'normalize_input': True,\n",
    "   'normalize_value': True,\n",
    "   'num_actors': 4096,\n",
    "   'player': {'render': True},\n",
    "   'ppo': True,\n",
    "   'reward_shaper': {'scale_value': 0.1},\n",
    "   'schedule_type': 'standard',\n",
    "   'score_to_win': 20000,\n",
    "   'tau': 0.95,\n",
    "   'truncate_grads': True,\n",
    "   'use_smooth_clamp': True,\n",
    "   'value_bootstrap': True},\n",
    "  'model': {'name': 'continuous_a2c_logstd'},\n",
    "  'network': {'mlp': {'activation': 'elu',\n",
    "    'initializer': {'name': 'default'},\n",
    "    'units': [512, 256, 128]},\n",
    "   'name': 'actor_critic',\n",
    "   'separate': False,\n",
    "   'space': {'continuous': {'fixed_sigma': True,\n",
    "     'mu_activation': 'None',\n",
    "     'mu_init': {'name': 'default'},\n",
    "     'sigma_activation': 'None',\n",
    "     'sigma_init': {'name': 'const_initializer', 'val': 0}}}},\n",
    "  'seed': 5}}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "swX1oGIQavbI",
    "outputId": "5489fcfb-4426-499d-909e-5b9afa7546dd"
   },
   "outputs": [],
   "source": [
    "import yaml\n",
    "from rl_games.torch_runner import Runner\n",
    "\n",
    "env_name = 'ant'  # @param ['ant', 'humanoid']\n",
    "configs = {\n",
    "    'ant' : ant_config,\n",
    "    'humanoid' : humanoid_config\n",
    "}\n",
    "networks = {\n",
    "    'ant' : 'runs/ant/nn/ant-brax.pth',\n",
    "    'humanoid' : 'runs/humanoid/nn/humanoid-brax.pth'\n",
    "}\n",
    "\n",
    "config = configs[env_name]\n",
    "network_path = networks[env_name]\n",
    "config['params']['config']['full_experiment_name'] = env_name\n",
    "config['params']['config']['max_epochs'] = 1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "self.seed = 5\n",
      "Started to train\n",
      "Exact experiment name requested from command line: ant\n"
     ]
    },
    {
     "ename": "TypeError",
     "evalue": "BraxEnv.__init__() got multiple values for argument 'env_name'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[9], line 3\u001b[0m\n\u001b[1;32m      1\u001b[0m runner \u001b[38;5;241m=\u001b[39m Runner()\n\u001b[1;32m      2\u001b[0m runner\u001b[38;5;241m.\u001b[39mload(config)\n\u001b[0;32m----> 3\u001b[0m \u001b[43mrunner\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m{\u001b[49m\n\u001b[1;32m      4\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mtrain\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m      5\u001b[0m \u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/Projects/rl_games/rl_games/torch_runner.py:178\u001b[0m, in \u001b[0;36mRunner.run\u001b[0;34m(self, args)\u001b[0m\n\u001b[1;32m    171\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Run either train/play depending on the args.\u001b[39;00m\n\u001b[1;32m    172\u001b[0m \n\u001b[1;32m    173\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[1;32m    174\u001b[0m \u001b[38;5;124;03m    args (:obj:`dict`):  Args passed in as a dict obtained from a yaml file or some other config format.\u001b[39;00m\n\u001b[1;32m    175\u001b[0m \n\u001b[1;32m    176\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m        \n\u001b[1;32m    177\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m args[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrain\u001b[39m\u001b[38;5;124m'\u001b[39m]:\n\u001b[0;32m--> 178\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_train\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    179\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m args[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mplay\u001b[39m\u001b[38;5;124m'\u001b[39m]:\n\u001b[1;32m    180\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrun_play(args)\n",
      "File \u001b[0;32m~/Projects/rl_games/rl_games/torch_runner.py:146\u001b[0m, in \u001b[0;36mRunner.run_train\u001b[0;34m(self, args)\u001b[0m\n\u001b[1;32m    139\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Run the training procedure from the algorithm passed in.\u001b[39;00m\n\u001b[1;32m    140\u001b[0m \n\u001b[1;32m    141\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[1;32m    142\u001b[0m \u001b[38;5;124;03m    args (:obj:`dict`): Train specific args passed in as a dict obtained from a yaml file or some other config format.\u001b[39;00m\n\u001b[1;32m    143\u001b[0m \n\u001b[1;32m    144\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    145\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mStarted to train\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m--> 146\u001b[0m agent \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43malgo_factory\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcreate\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43malgo_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbase_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mrun\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    147\u001b[0m _restore(agent, args)\n\u001b[1;32m    148\u001b[0m _override_sigma(agent, args)\n",
      "File \u001b[0;32m~/Projects/rl_games/rl_games/common/object_factory.py:40\u001b[0m, in \u001b[0;36mObjectFactory.create\u001b[0;34m(self, name, **kwargs)\u001b[0m\n\u001b[1;32m     38\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m builder:\n\u001b[1;32m     39\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(name)\n\u001b[0;32m---> 40\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mbuilder\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/Projects/rl_games/rl_games/torch_runner.py:57\u001b[0m, in \u001b[0;36mRunner.__init__.<locals>.<lambda>\u001b[0;34m(**kwargs)\u001b[0m\n\u001b[1;32m     46\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Initialise the runner instance with algorithms and observers.\u001b[39;00m\n\u001b[1;32m     47\u001b[0m \n\u001b[1;32m     48\u001b[0m \u001b[38;5;124;03mInitialises runners and players for all algorithms available in the library using `rl_games.common.object_factory.ObjectFactory`\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     53\u001b[0m \n\u001b[1;32m     54\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m     56\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39malgo_factory \u001b[38;5;241m=\u001b[39m object_factory\u001b[38;5;241m.\u001b[39mObjectFactory()\n\u001b[0;32m---> 57\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39malgo_factory\u001b[38;5;241m.\u001b[39mregister_builder(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124ma2c_continuous\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;28;01mlambda\u001b[39;00m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs : \u001b[43ma2c_continuous\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mA2CAgent\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m     58\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39malgo_factory\u001b[38;5;241m.\u001b[39mregister_builder(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124ma2c_discrete\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;28;01mlambda\u001b[39;00m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs : a2c_discrete\u001b[38;5;241m.\u001b[39mDiscreteA2CAgent(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)) \n\u001b[1;32m     59\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39malgo_factory\u001b[38;5;241m.\u001b[39mregister_builder(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124msac\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;28;01mlambda\u001b[39;00m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: sac_agent\u001b[38;5;241m.\u001b[39mSACAgent(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs))\n",
      "File \u001b[0;32m~/Projects/rl_games/rl_games/algos_torch/a2c_continuous.py:27\u001b[0m, in \u001b[0;36mA2CAgent.__init__\u001b[0;34m(self, base_name, params)\u001b[0m\n\u001b[1;32m     18\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, base_name, params):\n\u001b[1;32m     19\u001b[0m \u001b[38;5;250m    \u001b[39m\u001b[38;5;124;03m\"\"\"Initialise the algorithm with passed params\u001b[39;00m\n\u001b[1;32m     20\u001b[0m \n\u001b[1;32m     21\u001b[0m \u001b[38;5;124;03m    Args:\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     24\u001b[0m \n\u001b[1;32m     25\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[0;32m---> 27\u001b[0m     \u001b[43ma2c_common\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mContinuousA2CBase\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__init__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbase_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     28\u001b[0m     obs_shape \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mobs_shape\n\u001b[1;32m     29\u001b[0m     build_config \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m     30\u001b[0m         \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mactions_num\u001b[39m\u001b[38;5;124m'\u001b[39m : \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mactions_num,\n\u001b[1;32m     31\u001b[0m         \u001b[38;5;124m'\u001b[39m\u001b[38;5;124minput_shape\u001b[39m\u001b[38;5;124m'\u001b[39m : obs_shape,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     35\u001b[0m         \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mnormalize_input\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnormalize_input,\n\u001b[1;32m     36\u001b[0m     }\n",
      "File \u001b[0;32m~/Projects/rl_games/rl_games/common/a2c_common.py:1156\u001b[0m, in \u001b[0;36mContinuousA2CBase.__init__\u001b[0;34m(self, base_name, params)\u001b[0m\n\u001b[1;32m   1155\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, base_name, params):\n\u001b[0;32m-> 1156\u001b[0m     \u001b[43mA2CBase\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__init__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbase_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1158\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mis_discrete \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[1;32m   1159\u001b[0m     action_space \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39menv_info[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124maction_space\u001b[39m\u001b[38;5;124m'\u001b[39m]\n",
      "File \u001b[0;32m~/Projects/rl_games/rl_games/common/a2c_common.py:129\u001b[0m, in \u001b[0;36mA2CBase.__init__\u001b[0;34m(self, base_name, params)\u001b[0m\n\u001b[1;32m    127\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39menv_info \u001b[38;5;241m=\u001b[39m config\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124menv_info\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m    128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39menv_info \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 129\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mvec_env \u001b[38;5;241m=\u001b[39m \u001b[43mvecenv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcreate_vec_env\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnum_actors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv_config\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    130\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39menv_info \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mvec_env\u001b[38;5;241m.\u001b[39mget_env_info()\n\u001b[1;32m    131\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n",
      "File \u001b[0;32m~/Projects/rl_games/rl_games/common/vecenv.py:274\u001b[0m, in \u001b[0;36mcreate_vec_env\u001b[0;34m(config_name, num_actors, **kwargs)\u001b[0m\n\u001b[1;32m    272\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcreate_vec_env\u001b[39m(config_name, num_actors, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m    273\u001b[0m     vec_env_name \u001b[38;5;241m=\u001b[39m configurations[config_name][\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mvecenv_type\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[0;32m--> 274\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mvecenv_config\u001b[49m\u001b[43m[\u001b[49m\u001b[43mvec_env_name\u001b[49m\u001b[43m]\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_actors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/Projects/rl_games/rl_games/common/vecenv.py:280\u001b[0m, in \u001b[0;36m<lambda>\u001b[0;34m(config_name, num_actors, **kwargs)\u001b[0m\n\u001b[1;32m    277\u001b[0m register(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mRAY\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;28;01mlambda\u001b[39;00m config_name, num_actors, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: RayVecEnv(config_name, num_actors, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs))\n\u001b[1;32m    279\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mrl_games\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01menvs\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbrax\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m BraxEnv\n\u001b[0;32m--> 280\u001b[0m register(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mBRAX\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;28;01mlambda\u001b[39;00m config_name, num_actors, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: \u001b[43mBraxEnv\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_actors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m    282\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mrl_games\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01menvs\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01menvpool\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Envpool\n\u001b[1;32m    283\u001b[0m register(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mENVPOOL\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;28;01mlambda\u001b[39;00m config_name, num_actors, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Envpool(config_name, num_actors, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs))\n",
      "\u001b[0;31mTypeError\u001b[0m: BraxEnv.__init__() got multiple values for argument 'env_name'"
     ]
    }
   ],
   "source": [
    "runner = Runner()\n",
    "runner.load(config)\n",
    "runner.run({\n",
    "    'train': True,\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3s-95B-KlqE1"
   },
   "outputs": [],
   "source": [
    "from rl_games.envs.brax import BraxEnv\n",
    "\n",
    "from IPython.display import HTML, IFrame, display, clear_output\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "EikAyoGpmDpl",
    "outputId": "dd2006b1-e02a-4907-cf79-1b7d98d657e0"
   },
   "outputs": [],
   "source": [
    "agent = runner.create_player()\n",
    "agent.restore(network_path)\n",
    "\n",
    "env_config = runner.params['config']['env_config']\n",
    "num_actors = 1\n",
    "env = BraxEnv('', num_actors, **env_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "L6GRBDhsubx6",
    "outputId": "75ff7fb9-d19b-4b3d-91ca-132dc3e557a6"
   },
   "outputs": [],
   "source": [
    "qps = []\n",
    "obs = env.reset()\n",
    "total_reward = 0\n",
    "num_steps = 0\n",
    "\n",
    "class QP:\n",
    "    def __init__(self, qp):\n",
    "        self.pos = jax.numpy.squeeze(qp.pos, axis=0)\n",
    "        self.rot = jax.numpy.squeeze(qp.rot, axis=0)\n",
    "\n",
    "is_done = False\n",
    "while not is_done:\n",
    "    qps.append(QP(env.env._state.qp))\n",
    "    act = agent.get_action(obs)\n",
    "    obs, reward, is_done, info = env.step(act.unsqueeze(0))\n",
    "    total_reward += reward.item()\n",
    "    num_steps += 1\n",
    "\n",
    "print('Total Reward: ', total_reward)\n",
    "print('Num steps: ', num_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-iM6S8V55i5a"
   },
   "outputs": [],
   "source": [
    "def visualize(sys, qps):\n",
    "    return HTML(html.render(sys, qps))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 480
    },
    "id": "7xoJWkUI5kbF",
    "outputId": "59be7c1b-2700-44ff-9b90-b5ba1aba80e7"
   },
   "outputs": [],
   "source": [
    "display(visualize(env.env._env.sys, qps))"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "include_colab_link": true,
   "name": "rlg_colab_brax.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "jax312",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
